import random
from typing import Optional
import numpy as np
from tqdm import tqdm
from abc import ABC, abstractmethod
from copy import deepcopy
import sys
import torch

from sentence_transformers import SentenceTransformer, util
from sklearn.metrics import confusion_matrix, f1_score, recall_score, precision_score, accuracy_score

from utils.utils import load_json, save_json
from evaluation.simple_answer_check import simple_answer_check
from evaluation.latex_answer_check import latex_answer_check

# Seed Python's built-in `random` module once, at import time, with a fixed value.
random.seed(0)

class BaseEvaluation(ABC):
    """Abstract parent of the evaluators defined in this module.

    It only stores the dataset; subclasses provide the metric computation by
    implementing :meth:`get_results`.
    """

    def __init__(self, dataset):
        """Store ``dataset`` on the instance as ``self.dataset``."""
        self.dataset = dataset

    @abstractmethod
    def get_results(self):
        """Compute the evaluation results; must be overridden by subclasses."""
        return None

class AnswerEvaluation(BaseEvaluation):
    """Grade the sampled outputs of every problem with ``latex_answer_check``.

    When ``split_type == "extract_answer"`` the strings in
    ``data["output"]["answer"]`` are graded and no split pattern is passed to
    the checker; for any other value the texts in
    ``data["output"]["solution"]`` are graded and the configured
    ``extract_pattern`` is forwarded as ``split``.
    """

    def __init__(self, dataset: list, split_type: Optional[str] = None):
        super().__init__(dataset)
        self.split_type = split_type

    def print_results(self, total_acc, each_avg_acc):
        """Print both accuracies as percentages of the dataset size.

        Args:
            total_acc: Number of problems with at least one correct output.
            each_avg_acc: Sum over problems of the fraction of correct outputs.
        """
        num_problems = len(self.dataset)
        if num_problems == 0:
            # Nothing was evaluated: report NaN instead of dividing by zero.
            total_pct = avg_pct = float("nan")
        else:
            total_pct = round(100 * total_acc / num_problems, 2)
            avg_pct = round(100 * each_avg_acc / num_problems, 2)
        print(f"Total Accuracy: {total_pct}%")
        print(f"Avg Accuracy: {avg_pct}%")

    def check_correctness(self, output, answer, config_for_type, split):
        """Return whether ``output`` matches ``answer`` according to the checker.

        Args:
            output: One generated answer/solution.
            answer: The reference answer.
            config_for_type: Mapping providing ``"eval_policy"`` and
                ``"extract_policy"`` for ``latex_answer_check``.
            split: Split pattern forwarded to ``latex_answer_check`` (may be None).

        Returns:
            The checker's verdict, or ``False`` for an empty output or when the
            checker raises.
        """
        # The previous code tried ``output.split("boxed")[1]`` only when the
        # output was empty, which can never succeed; an empty output is simply
        # incorrect.
        if not output:
            return False
        try:
            return latex_answer_check(
                output,
                answer,
                eval_policy=config_for_type["eval_policy"],
                extract_policy=config_for_type["extract_policy"],
                split=split,
            )
        except Exception as e:
            # Best effort on purpose: one unparsable output must not abort the
            # whole evaluation run.
            print(f"Error checking correctness: {e}")
            return False

    def get_results(self):
        """Grade every output, print the accuracies and return the dataset.

        Each sample gains ``data["output"]["correctness"]``, the list of
        per-output verdicts. The dataset is modified in place.

        Returns:
            ``self.dataset`` with the correctness lists attached.
        """
        config = load_json("src/evaluation/eval_config.json")
        data_type = "MATH"
        config_for_type = config["policy"][data_type]

        # Both choices depend only on ``split_type``, so resolve them once.
        use_extracted = self.split_type == "extract_answer"
        output_key = "answer" if use_extracted else "solution"
        split = None if use_extracted else config["extract_pattern"]

        total_acc = 0
        each_avg_acc = 0
        for data in tqdm(self.dataset):
            outputs = data["output"][output_key]
            correctness = [
                self.check_correctness(output, data["answer"], config_for_type, split)
                for output in outputs
            ]
            num_correct = sum(int(result) for result in correctness)

            total_acc += int(num_correct > 0)
            # A problem without any output contributes 0 to the average
            # instead of raising ZeroDivisionError.
            if outputs:
                each_avg_acc += num_correct / len(outputs)
            data["output"]['correctness'] = correctness

        self.print_results(total_acc, each_avg_acc)

        return self.dataset

class DiversityEvaluation(BaseEvaluation):
    """Score how similar the sampled solutions of each problem are to each other.

    The score is the mean pairwise cosine similarity between sentence
    embeddings of the solutions (stored under the key ``"self-SBert"``).
    """

    def __init__(self, dataset: list):
        super().__init__(dataset)
        self.model = SentenceTransformer("Alibaba-NLP/gte-large-en-v1.5", trust_remote_code=True)
        # Number of samples that contributed to the running average.
        self.cal_num = 0

    def print_results(self, avg_diversity):
        """Print each accumulated score divided by the number of scored samples.

        Args:
            avg_diversity: Mapping from metric name to the summed score.
        """
        for k, v in avg_diversity.items():
            # No sample had two or more solutions: report NaN instead of
            # dividing by zero.
            mean_score = round(v / self.cal_num, 3) if self.cal_num else float("nan")
            print(f"Diversity ({k}): {mean_score}")

    def calculate_diversity(self, solutions):
        """Return the mean cosine similarity over all unordered solution pairs.

        Args:
            solutions: List of solution texts to embed.

        Returns:
            The mean pairwise similarity as a Python ``float``, or ``None``
            when fewer than two solutions are given (no pair exists).
        """
        num_solutions = len(solutions)
        num_pairs = num_solutions * (num_solutions - 1) // 2
        if num_pairs == 0:
            return None

        ## Self-Bert score (for diversity)
        with torch.no_grad():
            embeddings = self.model.encode(solutions, convert_to_tensor=True)
            similarity = util.cos_sim(embeddings, embeddings)
            # The strictly-lower triangle holds every unordered pair exactly
            # once; summing it in one tensor op replaces the per-element
            # Python loop and its device-to-host copies.
            pair_sum = torch.tril(similarity, diagonal=-1).sum().item()
        del embeddings
        del similarity
        torch.cuda.empty_cache()
        return pair_sum / num_pairs

    def get_results(self):
        """Attach a diversity score to a deep copy of every sample and print the mean.

        Samples with fewer than two solutions get ``None`` and are excluded
        from the printed average.

        Returns:
            A deep copy of the dataset where each ``sample["output"]`` has a
            ``"diversity"`` dict with the key ``"self-SBert"``.
        """
        results = deepcopy(self.dataset)
        avg_diversity = {"self-SBert" : 0}
        # Reset the counter so that calling this method again does not mix the
        # new sum with a stale sample count.
        self.cal_num = 0

        for i, data in enumerate(tqdm(self.dataset)):
            # Empty solutions are embedded as the literal text "None";
            # assumes each non-empty solution is a list of step strings — TODO confirm.
            solutions = ["None" if not sol else "\n".join(sol) for sol in data['output']['solution']]
            diversity = self.calculate_diversity(solutions) if len(solutions) > 1 else None
            results[i]["output"]["diversity"] = {"self-SBert": diversity}
            if diversity is not None:
                avg_diversity["self-SBert"] += diversity
                self.cal_num += 1

        self.print_results(avg_diversity)
        return results

class ErrorEvaluation(BaseEvaluation):
    """Measure how well predicted first-error steps match the annotated ones.

    Every sample is read through ``data["major_error_type"]``,
    ``data["error_step"]``, ``data["subject"]``, ``data["level"]`` and
    ``data["output"]["first_error_step"]`` (``None`` meaning that no error was
    predicted).
    """

    # Entries of ``results_report`` that hold raw confusion counts rather than
    # ratios; they are left out of the percentage print-out.
    _COUNT_KEYS = (
        "True Positive (TP)",
        "True Negative (TN)",
        "False Positive (FP)",
        "False Negative (FN)",
    )

    def __init__(self, dataset: list):
        super().__init__(dataset)

    @staticmethod
    def _mean(values):
        """Return the arithmetic mean of ``values``, or NaN when it is empty."""
        return sum(values) / len(values) if values else float("nan")

    def _group_score(self, group, analysis_mode):
        """Score one ``{"labels": [...], "preds": [...]}`` group.

        ``"f1"`` returns the F1 score of the group; ``"wrong_acc"`` returns the
        mean prediction over the entries whose label is 1.

        Raises:
            ValueError: If ``analysis_mode`` is neither ``"f1"`` nor ``"wrong_acc"``.
        """
        if analysis_mode == "f1":
            return f1_score(group["labels"], group["preds"])
        if analysis_mode == "wrong_acc":
            on_wrong = [pred for label, pred in zip(group["labels"], group["preds"]) if label == 1]
            return self._mean(on_wrong)
        # Previously an unknown mode silently reused the last computed value.
        raise ValueError(f"Invalid analysis mode: {analysis_mode}")

    def print_results(self, results_report, error_type_report, subject_report, level_report, analysis_mode, print_mode=None):
        """Print the metrics as percentages rounded to three decimals.

        Args:
            results_report: Mapping from metric name to a ratio.
            error_type_report: Mapping from error type to a labels/preds group.
            subject_report: Mapping from subject to a labels/preds group.
            level_report: Mapping from level name to a labels/preds group.
            analysis_mode: ``"f1"`` or ``"wrong_acc"``; selects how subject and
                level groups are scored.
            print_mode: ``"only_copy"`` prints just the comma-separated value
                line; anything else also prints titled, per-entry lines.

        Groups without any usable entry are printed as ``nan``.
        """
        verbose = print_mode != "only_copy"
        # Generators keep the computation interleaved with the printing, in
        # the same order as the sections are listed.
        sections = (
            ("Main Metric Results",
             ((name, value) for name, value in results_report.items() if name not in self._COUNT_KEYS)),
            ("\nResults by Error Type",
             ((name, self._mean(group["preds"])) for name, group in error_type_report.items())),
            ("\nResults by Subject",
             ((name, self._group_score(group, analysis_mode)) for name, group in subject_report.items())),
            ("\nResults by Level",
             ((name, self._group_score(group, analysis_mode)) for name, group in level_report.items())),
        )

        total_results = []
        for title, rows in sections:
            if verbose:
                print(title)
            for name, score in rows:
                rounded = round(score * 100, 3)
                total_results.append(str(rounded))
                if verbose:
                    print(f"{name} : {rounded}")

        if verbose:
            print("\nCopy The Results\n")
        print(", ".join(total_results))

    def calculate_metric(self, labels, predictions):
        """Return ``(tn, fp, fn, tp, accuracy, f1, precision, recall)`` for binary labels."""
        # Pinning the label set keeps the matrix 2x2 even when only one class
        # occurs, so the four-way unpacking below cannot fail.
        tn, fp, fn, tp = confusion_matrix(labels, predictions, labels=[0, 1]).ravel()

        acc = accuracy_score(labels, predictions)
        f1 = f1_score(labels, predictions)
        pr = precision_score(labels, predictions)
        recall = recall_score(labels, predictions)

        return tn, fp, fn, tp, acc, f1, pr, recall

    def get_results(self):
        """Compute, print and return the error-identification metrics.

        Returns:
            ``(results_report, error_type_report, subject_report, level_report)``.
            ``results_report`` maps metric names to ratios (NaN when the
            relevant subset of the dataset is empty); the other three map a
            group name to ``{"labels": [...], "preds": [...]}``.
        """
        # error_type_list = ["calculation_error", "misunderstanding", "logical_error", "redundancy_error", "correct"]
        error_type_list = ["calculation_error", "misunderstanding", "logical_error", "correct"]
        subject_list = ["Algebra", "Prealgebra", "Intermediate Algebra" , "Counting & Probability", "Precalculus", "Number Theory", "Geometry"]
        level_dict = {"1": "Easy", "2": "Easy", "3": "Easy", "4": "Hard", "5": "Hard"}

        error_type_report = {name: {"labels": [], "preds": []} for name in error_type_list}
        subject_report = {name: {"labels": [], "preds": []} for name in subject_list}
        level_report = {name: {"labels": [], "preds": []} for name in ("Easy", "Hard")}

        ## label   -> 0: the solution has no first error, 1: it has one
        ## flagged -> 0: no error predicted in any step, 1: an error predicted
        ##            somewhere (counts as right on an erroneous solution even
        ##            if the located step is wrong)
        ## hit     -> erroneous solution: 1 only if the predicted step equals
        ##            the annotated first-error step; correct solution: 1 only
        ##            if no error was predicted
        correct_flags = []      # ``flagged`` for correct solutions
        wrong_flags = []        # ``flagged`` for erroneous solutions
        first_error_hits = []   # ``hit`` for erroneous solutions
        strict_hits = []        # predicted step exists and is >= annotated step
        for data in self.dataset:
            has_error = data["major_error_type"] != "correct"
            predicted_step = data["output"]["first_error_step"]
            label = int(has_error)
            flagged = int(predicted_step is not None)

            if has_error:
                hit = int(predicted_step == data["error_step"])
                wrong_flags.append(flagged)
                first_error_hits.append(hit)
                strict_hits.append(int(predicted_step is not None and predicted_step >= data["error_step"]))
            else:
                hit = 1 - flagged
                correct_flags.append(flagged)

            groups = (
                error_type_report[data["major_error_type"]],
                subject_report[data["subject"]],
                level_report[level_dict[str(data["level"])]],
            )
            for group in groups:
                group["labels"].append(label)
                group["preds"].append(hit)

        results_report = {
            "Error Identify Acc (Correct)" : 1 - self._mean(correct_flags),
            "Error Identify Acc (Wrong, strict)" : self._mean(strict_hits),
            "Error Identify Acc (Wrong)" : self._mean(wrong_flags),
            "First Error Identify Acc" : self._mean(first_error_hits)
        }

        # Show all results
        self.print_results(results_report, error_type_report, subject_report, level_report, analysis_mode = "wrong_acc")

        return results_report, error_type_report, subject_report, level_report

class ErrorEvaluationPRM(ErrorEvaluation):
    """Error-identification metrics computed separately for each threshold.

    ``data["output"]["first_error_step"]`` is expected to be a mapping indexed
    by ``str(threshold)`` whose values are the predicted step or ``None``.
    """

    def __init__(self, dataset: list, threshold:list):
        super().__init__(dataset)
        self.threshold_list = threshold

    @staticmethod
    def _safe_ratio(values):
        """Return the mean of ``values``, or NaN when there is nothing to average."""
        return sum(values) / len(values) if values else float("nan")

    def get_results(self):
        """Compute the metrics per threshold, print one value line each and return them.

        Returns:
            ``(results_report, error_type_report, subject_report, level_report)``,
            each keyed by ``str(threshold)``. Ratios whose underlying subset is
            empty are NaN instead of raising ZeroDivisionError.
        """
        error_type_list = ["calculation_error", "misunderstanding", "logical_error", "correct"]
        subject_list = ["Algebra", "Prealgebra", "Intermediate Algebra" , "Counting & Probability", "Precalculus", "Number Theory", "Geometry"]
        level_dict = {"1": "Easy", "2": "Easy", "3": "Easy", "4": "Hard", "5": "Hard"}
        threshold_keys = [str(threshold) for threshold in self.threshold_list]

        error_type_report = {
            key: {name: {"labels": [], "preds": []} for name in error_type_list}
            for key in threshold_keys
        }
        subject_report = {
            key: {name: {"labels": [], "preds": []} for name in subject_list}
            for key in threshold_keys
        }
        level_report = {
            key: {name: {"labels": [], "preds": []} for name in ("Easy", "Hard")}
            for key in threshold_keys
        }

        ## label   -> 0: the solution has no first error, 1: it has one
        ## flagged -> 0: no error predicted in any step, 1: an error predicted
        ##            somewhere (counts as right on an erroneous solution even
        ##            if the located step is wrong)
        ## hit     -> erroneous solution: 1 only if the predicted step equals
        ##            the annotated first-error step; correct solution: 1 only
        ##            if no error was predicted
        tallies = {
            key: {"correct_flags": [], "wrong_flags": [], "first_hits": [], "strict_hits": []}
            for key in threshold_keys
        }

        for data in self.dataset:
            has_error = data["major_error_type"] != "correct"
            label = int(has_error)
            for key in threshold_keys:
                predicted_step = data["output"]["first_error_step"][key]
                flagged = int(predicted_step is not None)
                tally = tallies[key]

                if has_error:
                    hit = int(predicted_step == data["error_step"])
                    tally["wrong_flags"].append(flagged)
                    tally["first_hits"].append(hit)
                    tally["strict_hits"].append(
                        int(predicted_step is not None and predicted_step >= data["error_step"])
                    )
                else:
                    hit = 1 - flagged
                    tally["correct_flags"].append(flagged)

                groups = (
                    error_type_report[key][data["major_error_type"]],
                    subject_report[key][data["subject"]],
                    level_report[key][level_dict[str(data["level"])]],
                )
                for group in groups:
                    group["labels"].append(label)
                    group["preds"].append(hit)

        results_report = {}
        for key in threshold_keys:
            tally = tallies[key]
            results_report[key] = {
                "Error Identify Acc (Correct)" : 1 - self._safe_ratio(tally["correct_flags"]),
                "Error Identify Acc (Wrong, strict)" : self._safe_ratio(tally["strict_hits"]),
                "Error Identify Acc (Wrong)" : self._safe_ratio(tally["wrong_flags"]),
                "First Error Identify Acc" : self._safe_ratio(tally["first_hits"])
            }

            self.print_results(results_report[key], error_type_report[key], subject_report[key], level_report[key], analysis_mode = "wrong_acc", print_mode="only_copy")

        return results_report, error_type_report, subject_report, level_report
    

class RewardBenchEvaluation(BaseEvaluation):
    """Check how often the chosen solution of a problem out-scores the rejected one.

    Step scores in ``data["output"]["step_scores"]`` are reduced to a single
    reward by each function named in ``func_list``; samples are paired through
    ``data["problem_id"]`` and ``data["solution_type"]``.
    """

    def __init__(self, dataset: list, func_list: list):
        super().__init__(dataset)
        self.func_list = func_list

    def _calculate_final_reward(self, output, func_name):
        """Reduce the step scores ``output`` to one reward.

        Supported names: ``min``, ``max``, ``prod``, ``geometric_mean``,
        ``mean``, ``mean_logit``, ``mean_odd`` and ``last``.

        Returns:
            The aggregated reward, or ``None`` for an unknown ``func_name``.

        Note:
            ``mean_logit`` and ``mean_odd`` divide by ``1 - p``; a score of
            exactly 1 (or 0 for the logit) yields inf/nan from NumPy.
        """
        if func_name=="min":
            return min(output)
        elif func_name=="max":
            return max(output)
        elif func_name=="prod":
            prod = 1
            for out in output:
                prod*=out
            return prod
        elif func_name=="geometric_mean":
            # Same reducer as the other evaluators in this module offer.
            g_mean = 1
            for out in output:
                g_mean*=out
            return g_mean**(1/len(output))
        elif func_name=="mean":
            return sum(output)/len(output)
        elif func_name=="mean_logit":
            p = np.array(output)
            logit = np.log(p / (1 - p))
            # Sigmoid of the mean logit maps the average back to a probability.
            mean_logit = 1 / (1 + np.exp(-np.mean(logit)))
            return mean_logit
        elif func_name=="mean_odd":
            p = np.array(output)
            odds = p / (1 - p)
            mean_odd = np.maximum(0, np.mean(odds))
            return mean_odd
        elif func_name=="last":
            return output[-1]
        else:
            return None

    def get_results(self):
        """Compare chosen and rejected rewards per problem and print the accuracies.

        Returns:
            Mapping from function name to a list with one entry per problem:
            1 if the chosen reward is strictly greater than the rejected one,
            otherwise 0.

        Raises:
            ValueError: If a name in ``func_list`` is not a supported reducer.
        """
        ### Calculate total results
        grouped = {}
        for data in self.dataset:
            final_reward = {}
            for func_name in self.func_list:
                reward = self._calculate_final_reward(data["output"]["step_scores"], func_name)
                if reward is None:
                    # Fail here with a clear message rather than later with
                    # "'>' not supported between NoneType" during comparison.
                    raise ValueError(f"Unsupported reward function: {func_name}")
                final_reward[func_name] = reward

            grouped.setdefault(str(data["problem_id"]), []).append({
                "final_reward" : final_reward,
                "solution_type": data["solution_type"]
            })

        ### Compare the reward of chosen and rejected solution
        compare_results = {}
        for func_name in self.func_list:
            outcomes = []
            for entries in grouped.values():
                # A side missing for a problem keeps the default 0; with
                # several entries on one side the last one wins.
                chosen_reward=0; rejected_reward=0
                for entry in entries:
                    if entry["solution_type"]=="chosen":
                        chosen_reward = entry["final_reward"][func_name]
                    else:   # rejected
                        rejected_reward = entry["final_reward"][func_name]
                outcomes.append(1 if chosen_reward > rejected_reward else 0)
            compare_results[func_name] = outcomes

        print("### Reward Accuracy ###")
        for func_name in self.func_list:
            outcomes = compare_results[func_name]
            # No problems at all: print NaN instead of dividing by zero.
            accuracy = 100*sum(outcomes)/len(outcomes) if outcomes else float("nan")
            print(round(accuracy,3))

        return compare_results

class BoNEvaluation(BaseEvaluation):
    """Evaluate answer selection by majority voting or by reward scores.

    ``eval_type`` selects the strategy: ``"self_consistency"`` (majority
    vote), ``"best_of_n"`` (one score per answer) or ``"best_of_n_prm"``
    (step scores per answer, reduced with ``prm_func``).
    """

    def __init__(self, dataset: list, eval_type: Optional[str] = None, prm_func: Optional[str] = None):
        super().__init__(dataset)
        self.eval_type = eval_type
        self.prm_func = prm_func

    def _calculate_final_reward(self, output, func_name):
        """Reduce the step scores ``output`` to one reward.

        An empty ``output`` yields the neutral value 0.5; an unknown
        ``func_name`` yields ``None``.
        """
        if len(output)==0:
            return 0.5
        if func_name=="min":
            return min(output)
        elif func_name=="max":
            return max(output)
        elif func_name=="prod":
            prod = 1
            for out in output:
                prod*=out
            return prod
        elif func_name=="geometric_mean":
            g_mean = 1
            for out in output:
                g_mean*=out
            return g_mean**(1/len(output))
        elif func_name=="mean":
            return sum(output)/len(output)
        elif func_name=="mean_logit":
            p = np.array(output)
            logit = np.log(p / (1 - p))
            # Sigmoid of the mean logit maps the average back to a probability.
            mean_logit = 1 / (1 + np.exp(-np.mean(logit)))
            return mean_logit
        elif func_name=="mean_odd":
            p = np.array(output)
            odds = p / (1 - p)
            mean_odd = np.maximum(0, np.mean(odds))
            return mean_odd
        elif func_name=="last":
            return output[-1]
        else:
            return None

    def check_correctness(self, output, answer):
        """Return whether ``output`` matches ``answer``.

        An empty output, or a checker that raises, counts as incorrect.
        """
        # The previous code tried ``output.split("boxed")[1]`` only when the
        # output was empty, which can never succeed; an empty output is simply
        # incorrect.
        if not output:
            return False
        try:
            return latex_answer_check(output, answer, eval_policy="aggressive", extract_policy="flex", split=None)
        except Exception as e:
            # Best effort on purpose: one unparsable answer must not abort the run.
            print(f"Error checking correctness: {e}")
            return False

    def self_consistency(self, answer_list):
        """Return the index (as a string) of the first answer of the largest group.

        Answers are grouped greedily: each answer joins the first earlier
        group whose representative it matches according to
        ``latex_answer_check``; otherwise it starts a new group. Ties between
        groups are resolved in favour of the group that appeared first.

        Raises:
            ValueError: If ``answer_list`` is empty.
        """
        groups = []  # each entry: [index of first member, representative answer, votes]
        for idx, answer in enumerate(answer_list):
            if len(answer)==0:
                answer = "None"     ### for empty answer
            for group in groups:
                if latex_answer_check(answer, group[1], eval_policy="aggressive", extract_policy="flex", split=None):
                    group[2] += 1
                    break
            else:
                groups.append([idx, answer, 1])

        # ``max`` returns the first maximal element, i.e. the earliest group.
        winner = max(groups, key=lambda group: group[2])
        return str(winner[0])

    def extract_top_response(self, answer_list, scores):
        """Select one answer from ``answer_list`` according to ``self.eval_type``.

        Args:
            answer_list: Candidate answers.
            scores: One score per answer (``best_of_n``), one list of step
                scores per answer (``best_of_n_prm``); ignored for
                ``self_consistency``.

        Returns:
            The selected answer. Score ties are broken by majority vote among
            the tied answers.

        Raises:
            ValueError: For an unknown ``eval_type`` or when a score-based
                strategy is used without scores.
        """
        if self.eval_type=="self_consistency":
            return answer_list[int(self.self_consistency(answer_list))]
        if self.eval_type not in ("best_of_n", "best_of_n_prm"):
            raise ValueError("Invalid eval type")
        if scores is None:
            raise ValueError(f"Eval type '{self.eval_type}' requires scores")

        if self.eval_type=="best_of_n_prm":
            score_list = [self._calculate_final_reward(score, func_name=self.prm_func) for score in scores]
        else:
            score_list = scores

        max_score = max(score_list)
        max_score_index = [i for i, score in enumerate(score_list) if score == max_score]
        if len(max_score_index)>1:
            tied_answers = [answer_list[i] for i in max_score_index]
            return tied_answers[int(self.self_consistency(tied_answers))]
        return answer_list[max_score_index[0]]

    def get_results(self):
        """Print the accuracy for every sampling budget and return the dataset.

        For each budget ``n`` only the first ``n`` answers (and scores) of a
        sample are considered. One line is printed per budget, with one value
        per data source; NaN is printed when nothing was evaluated.

        Returns:
            ``self.dataset`` (unchanged).
        """
        # data_source = ["MATH", "agieval_gaokao", "agieval_sat"]
        data_source = ["MATH"]
        sampling_num_list = [1, 2, 4, 8, 16, 32, 64, 128, 256]
        # Guarding on a non-empty dataset avoids an IndexError on ``dataset[0]``.
        if self.eval_type!="self_consistency" and self.dataset and len(self.dataset[0]["score"])==16:
            sampling_num_list = [2, 4, 8, 16]  # generative reward model

        total_acc = {
            str(sampling_num): {source: [] for source in data_source}
            for sampling_num in sampling_num_list
        }

        for data in tqdm(self.dataset):
            answers = data["output"]["answer"]
            scores = data["score"] if "score" in data else None
            for sampling_num in sampling_num_list:
                # sampling the top response (reward base)
                budget_scores = scores[:sampling_num] if scores is not None else None
                top_response = self.extract_top_response(answers[:sampling_num], budget_scores)
                # checking correctness
                result = self.check_correctness(top_response, data["answer"])
                total_acc[str(sampling_num)][data["data_source"]].append(result)

        print("######## Results ########")
        for sampling_num in sampling_num_list:
            for source in data_source:
                hits = total_acc[str(sampling_num)][source]
                accuracy = 100*sum(hits)/len(hits) if hits else float("nan")
                print(round(accuracy,3), end=", ")
            print()
        return self.dataset

class MartiniEvaluation(BaseEvaluation):
    """Compare the reward score of the chosen solution with rejected ones.

    Each dataset record is read through the keys ``"problem"``,
    ``"solution_reference"`` (one source label per scored solution),
    ``"score"`` (one entry per solution) and, for the ``"pairwise"`` function
    only, ``"chosen_position"``.  The solution whose source label is
    ``"human_to_GPT-4"`` is the chosen one; solutions with any other label are
    rejected ones.
    """

    # Source label that marks the chosen solution of a record.
    _CHOSEN_SOURCE = "human_to_GPT-4"
    # Source labels whose scores go to the "hard" rejected bucket.
    _HARD_MODELS = (
        "gpt-4o-2024-05-13",
        "claude-3-sonnet-20240229",
        "meta-llama/Meta-Llama-3-70B-Instruct",
        "google/gemma-2-27b-it",
        "microsoft/Phi-3-medium-4k-instruct",
        "meta-llama/Meta-Llama-3-8B-Instruct",
    )
    # Source labels whose scores go to the "easy" rejected bucket.
    _EASY_MODELS = (
        "modified_with_GPT-4",
        "gpt-3.5-turbo-0125",
        "mistralai/Mixtral-8x7B-Instruct-v0.1",
        "deepseek-ai/DeepSeek-V2-Lite-Chat",
        "Qwen/Qwen1.5-7B-Chat",
        "google/gemma-7b-it",
        "WizardLMTeam/WizardMath-7B-V1.1",
        "peiyi9979/mistral-7b-sft",
    )

    def __init__(self, dataset: list, func_list: list, except_model: str=None):
        """Store the dataset and evaluation options.

        Args:
            dataset: records to evaluate.
            func_list: names of the scoring functions to evaluate; besides the
                aggregations of ``_calculate_final_reward`` the names
                ``"normal"`` (use the score as is) and ``"pairwise"`` are
                handled by ``processing_results``.
            except_model: source label to leave out of the rejected scores
                and of the easy/hard buckets, or ``None``.
        """
        super().__init__(dataset)
        self.func_list = func_list
        self.except_model = except_model

    def _calculate_final_reward(self, output, func_name):
        """Aggregate a sequence of scores into a single reward.

        Args:
            output: non-empty sequence of numeric scores.
            func_name: one of ``"min"``, ``"max"``, ``"prod"``,
                ``"geometric_mean"``, ``"mean"``, ``"mean_logit"``,
                ``"mean_odd"`` or ``"last"``.

        Returns:
            The aggregated reward.

        Raises:
            ValueError: if ``func_name`` is not a supported aggregation.
        """
        if func_name == "min":
            return min(output)
        if func_name == "max":
            return max(output)
        if func_name in ("prod", "geometric_mean"):
            product = 1
            for value in output:
                product *= value
            if func_name == "prod":
                return product
            return product**(1/len(output))
        if func_name == "mean":
            return sum(output)/len(output)
        if func_name == "mean_logit":
            # Average in logit space, then map back through the sigmoid.
            p = np.array(output)
            logit = np.log(p / (1 - p))
            return 1 / (1 + np.exp(-np.mean(logit)))
        if func_name == "mean_odd":
            # Mean of the odds p / (1 - p), floored at 0.
            p = np.array(output)
            odds = p / (1 - p)
            return np.maximum(0, np.mean(odds))
        if func_name == "last":
            return output[-1]
        # Previously this fell through and returned None, which only surfaced
        # later as a TypeError when the rewards were compared.
        raise ValueError(f"Unknown reward aggregation function: {func_name!r}")

    def _calculate_MRR(self, chosen_score, rejected_score, mode="mrr"):
        """Rank the best chosen score among all chosen and rejected scores.

        All scores are sorted in descending order; when a chosen and a
        rejected score are equal the rejected one is placed first (the tag
        ``"rej"`` sorts above ``"cho"``), so ties count against the chosen
        solution.

        Args:
            chosen_score: scores of the chosen solution(s); must not be empty.
            rejected_score: scores of the rejected solutions.
            mode: ``"mrr"`` returns ``1 / rank`` of the first chosen score;
                ``"rank"`` returns ``(N - rank) / (N - 1)`` where ``N`` is the
                total number of scores (``1.0`` when ``N == 1``).

        Raises:
            ValueError: if ``mode`` is unknown or ``chosen_score`` is empty.
        """
        if mode not in ("mrr", "rank"):
            raise ValueError(f"Unknown mode: {mode!r}")

        tagged = [(score, "rej") for score in rejected_score]
        tagged.extend((score, "cho") for score in chosen_score)
        ranked = sorted(tagged, key=lambda item: (item[0], item[1]), reverse=True)

        first_chosen_idx = next(
            (rank for rank, (_, tag) in enumerate(ranked, start=1) if tag == "cho"),
            None,
        )
        if first_chosen_idx is None:
            raise ValueError("chosen_score must contain at least one score")

        if mode == "mrr":
            return 1/first_chosen_idx
        if len(ranked) == 1:
            # A single candidate is trivially ranked first; avoids 0 / 0.
            return 1.0
        return (len(ranked) - first_chosen_idx)/(len(ranked)-1)

    def _calculate_reward_diff(self, chosen_score, rejected_score):
        """Sum the min-max-scaled differences of all chosen/rejected pairs.

        Scores are scaled to [0, 1] using the minimum and maximum over both
        lists.  The result is the sum (not the mean) of
        ``scaled_chosen - scaled_rejected`` over every pair, or ``0`` when all
        scores are equal.

        Args:
            chosen_score: list of chosen scores.
            rejected_score: list of rejected scores.
        """
        all_scores = chosen_score + rejected_score
        min_reward = min(all_scores)
        max_reward = max(all_scores)
        if min_reward == max_reward:
            return 0
        span = max_reward - min_reward

        def scaled(score):
            return (score - min_reward)/span   # min max scaling

        return sum(
            scaled(chosen_s) - scaled(rejected_s)
            for chosen_s in chosen_score
            for rejected_s in rejected_score
        )

    def _pairwise_wins(self, data, easy_model, hard_model):
        """Collect, per comparison, whether the score equals the chosen position."""
        wins = {"total": [], "easy": [], "hard": []}
        pairs = zip(data["solution_reference"], data["score"])
        for idx, (data_source, score) in enumerate(pairs):
            won = score == data["chosen_position"][idx]
            # NOTE(review): unlike the non-pairwise path, `except_model` is not
            # left out of "total" here -- confirm whether that is intended.
            wins["total"].append(won)
            if data_source in easy_model:
                wins["easy"].append(won)
            elif data_source in hard_model:
                wins["hard"].append(won)
        return wins

    def processing_results(self, data):
        """Split one record's scores into chosen / rejected buckets per function.

        Args:
            data: a dataset record (see the class docstring for the keys read).

        Returns:
            A dict with the record's ``"problem"`` and four mappings from
            function name to a list of scores: ``"chosen_score"``,
            ``"rejected_score"``, ``"rejected_score_easy"`` and
            ``"rejected_score_hard"``.
        """
        hard_model = [model for model in self._HARD_MODELS if model != self.except_model]
        easy_model = [model for model in self._EASY_MODELS if model != self.except_model]

        chosen_score = {func_name: [] for func_name in self.func_list}
        rejected_score = {func_name: [] for func_name in self.func_list}
        easy_data = {func_name: [] for func_name in self.func_list}
        hard_data = {func_name: [] for func_name in self.func_list}

        for func_name in self.func_list:
            if func_name == "pairwise":
                wins = self._pairwise_wins(data, easy_model, hard_model)
                # The chosen solution gets (number of wins + 1).  A rejected
                # solution gets 1 when the chosen one won the comparison and
                # otherwise a value larger than any possible chosen score.
                largest_reward = len(wins["total"])+2
                chosen_score[func_name].append(sum(wins["total"])+1)
                rejected_score[func_name] = [1 if won else largest_reward for won in wins["total"]]
                easy_data[func_name] = [1 if won else largest_reward for won in wins["easy"]]
                hard_data[func_name] = [1 if won else largest_reward for won in wins["hard"]]
                continue

            for data_source, score in zip(data["solution_reference"], data["score"]):
                if func_name == "normal":
                    tmp_score = score
                else:
                    tmp_score = self._calculate_final_reward(score, func_name=func_name)

                if data_source == self._CHOSEN_SOURCE:
                    chosen_score[func_name].append(tmp_score)
                elif data_source != self.except_model:
                    rejected_score[func_name].append(tmp_score)

                if data_source in easy_model:
                    easy_data[func_name].append(tmp_score)
                elif data_source in hard_model:
                    hard_data[func_name].append(tmp_score)

        return {
            "problem" : data["problem"],
            "chosen_score" : chosen_score,
            "rejected_score" : rejected_score,
            "rejected_score_easy" : easy_data,
            "rejected_score_hard" : hard_data,
        }

    @staticmethod
    def _safe_ratio(numerator, denominator):
        """Return ``numerator / denominator``, or NaN for a zero denominator."""
        return numerator/denominator if denominator else float("nan")

    def get_results(self):
        """Compute per-record metrics for every function and print their averages.

        For each function name the metrics ``"reward Acc"`` (worst chosen
        score strictly above the best rejected score), ``"reward Acc (w/ tie)"``
        (same with ``>=``), ``"reward MRR"``, ``"reward rank"`` and
        ``"reward diff"`` are collected per record.  The easy/hard buckets
        returned by ``processing_results`` are not aggregated here.

        One line per function is printed: accuracy, MRR and rank as
        percentages and the mean reward diff (``None`` for ``"pairwise"``);
        ``nan`` is printed when the dataset is empty.  For ``"normal"`` the
        tie-inclusive accuracy is printed first.

        Returns:
            ``{func_name: {metric_name: [per-record value, ...]}}``.
        """
        final_data_score = [self.processing_results(data) for data in self.dataset]

        ### Calculate total results
        metric_list = ["reward Acc", "reward MRR", "reward rank", "reward diff"]
        final_results = {}
        for func_name in self.func_list:
            metrics = {metric_name: [] for metric_name in metric_list}
            metrics["reward Acc (w/ tie)"] = []
            final_results[func_name] = metrics
            for data_score in final_data_score:
                chosen = data_score["chosen_score"][func_name]
                rejected = data_score["rejected_score"][func_name]
                worst_chosen = min(chosen)
                best_rejected = max(rejected)
                metrics["reward Acc (w/ tie)"].append(worst_chosen >= best_rejected)
                metrics["reward Acc"].append(worst_chosen > best_rejected)
                metrics["reward MRR"].append(self._calculate_MRR(chosen, rejected))
                metrics["reward rank"].append(self._calculate_MRR(chosen, rejected, mode="rank"))
                metrics["reward diff"].append(self._calculate_reward_diff(chosen, rejected))

        print("### Total Results ###")
        for func_name in self.func_list:
            metrics = final_results[func_name]
            if func_name == "normal":
                with_tie = metrics["reward Acc (w/ tie)"]
                print("w/ tie Acc: ", round(self._safe_ratio(100*sum(with_tie), len(with_tie)), 3))
            for metric, values in metrics.items():
                if metric == "reward Acc (w/ tie)":
                    continue
                if metric != "reward diff":
                    print(round(self._safe_ratio(100*sum(values), len(values)), 3), end=", ")
                elif func_name == "pairwise":
                    print("None", end=", ")
                else:
                    print(round(self._safe_ratio(sum(values), len(values)), 3), end=", ")
            print()

        return final_results